import pandas as p
import numpy as np
import sys
import os
#Class that is used to collect inflation information
class Inflator(object):
    """Convert prices between years using BLS CPI or BEA GDP deflator tables.

    Both tables are read from Excel workbooks when the instance is created
    and stored as ``{year: index_value}`` dicts in ``allUnadjustedData``
    (BLS CPI, all items, unadjusted) and ``gdpInflatorData`` (BEA).
    """

    def __init__(self, allUnadjustedFile=None, gdpInflatorFile=None):
        """Load and clean both index tables.

        Args:
            allUnadjustedFile: Path to the BLS CPI workbook. Defaults to
                ``PythonScripts/Data/CPI Data/AllItemsUnadjustedUpdated-2022.xlsx``
                (relative to the current working directory).
            gdpInflatorFile: Path to the BEA workbook. Defaults to
                ``PythonScripts/Data/BEA GDP Inflator/Section1all_xls-2022.xlsx``
                (relative to the current working directory).
        """
        # os.path.join uses the native separator on every platform. The old
        # darwin-vs-else check handed Windows backslash paths to Linux.
        if allUnadjustedFile is None:
            allUnadjustedFile = self._dataPath(
                'CPI Data', 'AllItemsUnadjustedUpdated-2022.xlsx')
        if gdpInflatorFile is None:
            gdpInflatorFile = self._dataPath(
                'BEA GDP Inflator', 'Section1all_xls-2022.xlsx')
        self.allUnadjustedFile = allUnadjustedFile
        self.gdpInflatorFile = gdpInflatorFile

        # Clean and store data
        self.allUnadjustedData = self.cleanData(self.allUnadjustedFile)
        self.gdpInflatorData = self.cleanDataBEA(self.gdpInflatorFile)

    @staticmethod
    def _dataPath(*parts):
        """Return a path under ``PythonScripts/Data`` built with native separators."""
        return os.path.join('PythonScripts', 'Data', *parts)

    @staticmethod
    def _inflate(table, price, year, to):
        """Scale ``price`` by ``table[to] / table[year]``.

        Prints a message and returns None when either year is not a key of
        ``table``.
        """
        if year in table and to in table:
            return price * (table[to] / table[year])
        print('One of the years are not in range of keys.')
        return None

    def inflateAll(self, price, year, to):
        """Inflate ``price`` from ``year`` to ``to`` using the all-items CPI.

        Returns:
            The adjusted price, or None (after printing a message) if either
            year is missing from the CPI table.
        """
        return self._inflate(self.allUnadjustedData, price, year, to)

    def inflateGDP(self, price, year, to):
        """Inflate ``price`` from ``year`` to ``to`` using the BEA GDP table.

        Returns:
            The adjusted price, or None (after printing a message) if either
            year is missing from the GDP table.
        """
        return self._inflate(self.gdpInflatorData, price, year, to)

    def cleanData(self, file):
        """Read a BLS CPI workbook and return ``{year: index}``."""
        print(file)
        return self._parseBLS(p.read_excel(file))

    @staticmethod
    def _parseBLS(preData):
        """Extract ``{year: index}`` from a BLS sheet already loaded as a DataFrame.

        Uses rows from position 11 onward. Years are taken from column 0 and
        index values from column 13.
        NOTE(review): column 13 is presumably the annual figure and rows 0-10
        presumably header/notes in the BLS export; confirm against the workbook.
        """
        years = list(preData.iloc[11:, 0])
        indices = list(preData.iloc[11:, 13])
        return dict(zip(years, indices))

    def cleanDataBEA(self, file, sheet='T10109-A'):
        """Read a BEA workbook sheet and return ``{year: index}``.

        Args:
            file: Path to the BEA Section 1 workbook.
            sheet: Sheet name to read. Its first 7 rows are skipped, so the
                next row becomes the header.
        """
        preData = p.read_excel(file, sheet_name=sheet, skiprows=7)
        return self._parseBEA(preData)

    @staticmethod
    def _parseBEA(preData, firstYear='1929'):
        """Extract ``{year: index}`` from a BEA sheet already loaded as a DataFrame.

        Year columns start at the column labelled ``firstYear`` and run to the
        last column. Values come from the first data row. Keys are ints and
        values are floats.

        Raises:
            ValueError: if no column is labelled ``firstYear``.
        """
        yearStartCol = list(preData.columns).index(firstYear)
        indices = np.array(list(preData.iloc[0, yearStartCol:])).astype(float)
        years = np.array(list(preData.columns)[yearStartCol:]).astype(int)
        return dict(zip(years, indices))